enable StaticCache for assisted generation #34797

Open · wants to merge 51 commits into main

Conversation

@yao-matrix (Author)

@gante , I implemented a version for this issue: #32946. Pls help comment, and I can iterate, thx.
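
For context, below is a minimal sketch of the usage this PR aims to enable: assisted (speculative) generation running on top of a static KV cache. The checkpoint names and kwargs are illustrative assumptions, not taken from the PR itself.

```python
# Sketch only: assisted generation combined with cache_implementation="static",
# the combination this PR enables. The checkpoints below are assumed examples.
from transformers import AutoModelForCausalLM, AutoTokenizer

main_ckpt = "meta-llama/Llama-2-7b-hf"                  # assumed main model
assistant_ckpt = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"   # assumed smaller assistant

tokenizer = AutoTokenizer.from_pretrained(main_ckpt)
model = AutoModelForCausalLM.from_pretrained(main_ckpt)
assistant = AutoModelForCausalLM.from_pretrained(assistant_ckpt)

inputs = tokenizer("Speculative decoding with a static cache:", return_tensors="pt")
out = model.generate(
    **inputs,
    assistant_model=assistant,          # turns on assisted generation
    cache_implementation="static",      # pre-allocated StaticCache
    max_new_tokens=32,
    do_sample=False,
)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```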

@yao-matrix marked this pull request as draft November 19, 2024 07:28
@yao-matrix marked this pull request as ready for review November 20, 2024 08:02
@yao-matrix (Author)

@gante , could you pls take a look? thx

@zucchini-nlp (Member) left a comment

@yao-matrix hey, gante is currently on a long vacation so I reviewed the PR for him. Thanks for adding support for this, Super cool work!

I left a few comments and also we'll need tests in tests/generation/test_utils.py file. I guess static cache now works with all types of candidate generators right?

Comment on lines 1744 to 1765

```python
if assistant_model is not None:
    # Pre-allocate the assistant model's cache with the same max_cache_len as the main model.
    assistant_model._get_cache(
        cache_implementation=generation_config.cache_implementation,
        batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
        max_cache_len=max_cache_length,
        device=device,
        model_kwargs=model_kwargs,
    )
```
@zucchini-nlp (Member)

Hmm, I think this will be called on the assistant model anyway when we call assistant.generate(), so there is no need for it here. We only need to remove self.generation_config.cache_implementation = None in the candidate generator.

@yao-matrix (Author)

The thing is: if we leave it to assistant_model.generate() (called from get_candidates) to set up the cache, then on the first call max_new_tokens is set to max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1), so the cache length gets set to int(self.num_assistant_tokens) + prompt_len. That is less than the actually needed cache length of max_token_length + prompt_length, and it leads to an assert during generation. So the key point here is that the assistant model's cache length should be the same as the main model's. I also noticed that this function already takes assistant_model as an argument but doesn't use it; I think it may be there for cases like this. That's the rationale behind this change.
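
To make the mismatch concrete, here is a small back-of-the-envelope sketch (hypothetical numbers, not from the PR) of why sizing the cache inside get_candidates is not enough:

```python
# Hypothetical numbers illustrating the cache-length mismatch described above.
prompt_len = 32
max_new_tokens_total = 128      # what the user asked the main model for
num_assistant_tokens = 5        # tokens the assistant speculates per round

# If the assistant's StaticCache were allocated on the first get_candidates()
# call, it would only cover one speculation window ...
cache_len_if_sized_lazily = prompt_len + num_assistant_tokens    # 37
# ... while later rounds need room for the whole generation.
cache_len_actually_needed = prompt_len + max_new_tokens_total    # 160

assert cache_len_if_sized_lazily < cache_len_actually_needed
# Writing past slot 37 of a pre-allocated StaticCache would assert out, which is
# why the assistant cache is pre-allocated up front with the main model's max_cache_len.
```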

@zucchini-nlp (Member)

Oh, I see, that makes sense. Then we can leave the cache init here.

Outdated review comments on src/transformers/generation/candidate_generator.py and src/transformers/cache_utils.py (two threads) were marked as resolved.
@zucchini-nlp (Member) left a comment

LGTM! We need some tests, and then I'll request a review from the core maintainer; after that we can merge.

@yao-matrix (Author)

> @yao-matrix hey, gante is currently on a long vacation so I reviewed the PR for him. Thanks for adding support for this, Super cool work!
>
> I left a few comments and also we'll need tests in tests/generation/test_utils.py file. I guess static cache now works with all types of candidate generators right?

@zucchini-nlp, the test_utils CI pass rate is the same before and after this PR, as shown below, so no regressions are introduced.
before:
=========================== short test summary info ============================
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_encoder_decoder_shared_encoder
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_num_assistant_tokens_heuristic_schedule
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_generation_early_exit
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_custom_logits_processor
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_default_max_length_warning
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_eos_token_id_int_and_list_beam_search
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_eos_token_id_int_and_list_top_k_top_sampling
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_generate_compile_fullgraph_tiny
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_generated_length_assisted_generation
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_max_new_tokens_encoder_decoder
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_min_length_if_input_embeds
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_model_kwarg_assisted_decoding_decoder_only
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_model_kwarg_assisted_decoding_encoder_decoder
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_model_kwarg_encoder_signature_filtering
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_prepare_inputs_for_generation_decoder_llm
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_speculative_decoding_equals_regular_decoding
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_stop_sequence_stopping_criteria
====== 17 failed, 51 passed, 19 skipped, 13 warnings in 133.78s (0:02:13) ======

after:
=========================== short test summary info ============================
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_encoder_decoder_shared_encoder
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_decoding_num_assistant_tokens_heuristic_schedule
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_assisted_generation_early_exit
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_custom_logits_processor
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_default_max_length_warning
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_eos_token_id_int_and_list_beam_search
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_eos_token_id_int_and_list_top_k_top_sampling
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_generate_compile_fullgraph_tiny
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_generated_length_assisted_generation
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_max_new_tokens_encoder_decoder
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_min_length_if_input_embeds
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_model_kwarg_assisted_decoding_decoder_only
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_model_kwarg_assisted_decoding_encoder_decoder
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_model_kwarg_encoder_signature_filtering
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_prepare_inputs_for_generation_decoder_llm
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_speculative_decoding_equals_regular_decoding
FAILED tests/generation/test_utils.py::GenerationIntegrationTests::test_stop_sequence_stopping_criteria
====== 17 failed, 51 passed, 19 skipped, 13 warnings in 133.78s (0:02:13) ======

@yao-matrix (Author)

> LGTM! We need some tests and then I am requesting review from the core maintainer, after that we can merge

thx for reviewing.

@zucchini-nlp (Member)

@yao-matrix no worries if some tests are failing and they are not related to the PR changes; they might just be flaky or will be fixed on main by us. From what I see, the only CI test affected by the PR is the one below, plus we need to see if the new test passes for all models:

tests/models/gemma2/test_modeling_gemma2.py::Gemma2ModelTest::test_assisted_decoding_with_num_logits_to_keep

@yao-matrix (Author)

@zucchini-nlp , any more comments for me to iterate? Thx.

@zucchini-nlp (Member)

@yao-matrix no, the only remaining thing is the CI, which is failing right now. I pointed out the relevant test in the previous comment, and it would be good if you could add one more test in tests/generation/test_utils.py that exercises static cache with assisted generation. That is all, actually.

At the end you will need to run make style to pass the codestyle CI check. Feel free to tag the core maintainer @ArthurZucker for review when the tests are ready and CI is green, or tag me if you need help / have questions :)
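
A minimal sketch of the kind of test being requested here (not the exact test that was added; the tiny checkpoint and kwargs are assumptions) could compare assisted generation with and without a static cache:

```python
# Sketch of a possible tests/generation/test_utils.py addition: greedy assisted
# generation should produce the same tokens with a dynamic or a static cache.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_assisted_generation_static_cache_matches_dynamic():
    ckpt = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed tiny model
    tokenizer = AutoTokenizer.from_pretrained(ckpt)
    model = AutoModelForCausalLM.from_pretrained(ckpt)
    assistant = AutoModelForCausalLM.from_pretrained(ckpt)
    inputs = tokenizer("The quick brown fox", return_tensors="pt")

    common = dict(assistant_model=assistant, max_new_tokens=10, do_sample=False)
    out_dynamic = model.generate(**inputs, **common)
    out_static = model.generate(**inputs, cache_implementation="static", **common)

    # Identical outputs regardless of cache implementation.
    assert torch.equal(out_dynamic, out_static)
```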

Comment on lines 116 to 121

```python
@parameterized.expand([(None, True), ("static", False)])
def test_assisted_decoding_with_num_logits_to_keep(self, cache_implementation, return_legacy_cache):
    if cache_implementation == "static":
        self.skipTest("Gemma2 has HybridCache which is not compatible with assisted decoding StaticCache")
    pass
```

@zucchini-nlp (Member)

Let's not skip entirely, but only the static-cache case, as we still need to check whether assisted generation works in Gemma2 :)

Maybe it will be skipped by the model's _supports_static_cache as I've commented above, but if not, we can skip only test_assisted_decoding_with_num_logits_to_keep_1_static (it may be named a bit differently).

@yao-matrix (Author)

I switched to _supports_static_cache to skip the case. For Gemma it's a bit different: it uses HybridCache but still claims _supports_static_cache = True, so I skip it in the model test file anyway. I will remove this skip after enabling HybridCache for assisted decoding, which I plan to do after this PR (pure StaticCache) is merged, thanks.
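
For illustration, the skip described above could look roughly like this (a hedged sketch, not the exact diff in the PR):

```python
# Sketch: skip the static-cache variant for models that do not support StaticCache.
if cache_implementation == "static" and not model_class._supports_static_cache:
    self.skipTest(f"{model_class.__name__} does not support StaticCache")
```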

@ArthurZucker (Collaborator) left a comment

Looks very nice, but we need to add a compile test to make sure this is compile compatible! The whole point of static cache is -> compile! 🤗

@yao-matrix (Author) commented Dec 11, 2024

> Looks very nice, but we need to add a compile test to make sure this is compile compatible! The whole point of static cache is -> compile! 🤗

@ArthurZucker I added a test_assisted_decoding_compile case based on test_generate_compile; the forward_only test passes for llama, while the end_to_end test fails for the same reason Joao commented on in test_generate_compile.
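
Roughly, the forward-only compile check mirrors the following sketch (illustrative tiny checkpoint; not the exact test_assisted_decoding_compile code):

```python
# Sketch: compile the main model's forward, then run assisted generation with a
# static cache on top of it. The checkpoint is an assumed tiny test model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt)
assistant = AutoModelForCausalLM.from_pretrained(ckpt)

model.forward = torch.compile(model.forward)  # forward-only compilation

inputs = tokenizer("compile me", return_tensors="pt")
_ = model.generate(
    **inputs,
    assistant_model=assistant,
    cache_implementation="static",
    max_new_tokens=8,
    do_sample=False,
)
```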

@yao-matrix (Author) commented Dec 13, 2024

@ArthurZucker @zucchini-nlp, please let me know if there are any further comments, thanks. BTW, I checked the failed CI case; it is not related to my changes.

@zucchini-nlp (Member)

Thanks, re-triggered the tests, let's wait for the core maintainer

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@yao-matrix (Author)

@ArthurZucker, @zucchini-nlp, I am wondering whether we could still land this PR in 2024 :)

@yao-matrix (Author)

@zucchini-nlp @ArthurZucker , any further comments on this?
